{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only CUDA device index 2 to this process; this is set before the later cell\n",
    "# that imports TensorFlow.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "SOURCE_DIR = os.path.dirname(os.path.abspath(__name__))\n",
    "sys.path.insert(0, f\"{SOURCE_DIR}/src\")\n",
    "sys.path.append(SOURCE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import XLNetTokenizer\n",
    "# Tokenizer for the 'huseinzol05/xlnet-base-bahasa-cased' checkpoint, loaded with\n",
    "# lower-casing switched off; make_feed_dict_bert uses it to split words into sub-words.\n",
    "tokenizer = XLNetTokenizer.from_pretrained(\n",
    "    'huseinzol05/xlnet-base-bahasa-cased', do_lower_case = False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the vocabulary file as UTF-8 explicitly instead of relying on the platform\n",
    "# default encoding.\n",
    "with open('vocab-xlnet-base.json', encoding='utf-8') as fopen:\n",
    "    data = json.load(fopen)\n",
    "\n",
    "# Indexed by integer label id / tag id when trees are rebuilt in later cells.\n",
    "LABEL_VOCAB = data['label']\n",
    "TAG_VOCAB = data['tag']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the serialized GraphDef from disk and import it into a fresh tf.Graph.\n",
    "# import_graph_def is called without a name argument; the tensors looked up in the\n",
    "# next cell are addressed with an 'import/' prefix.\n",
    "with tf.gfile.GFile('export/xlnet-base.pb.quantized', 'rb') as f:\n",
    "    graph_def = tf.GraphDef()\n",
    "    graph_def.ParseFromString(f.read())\n",
    "    with tf.Graph().as_default() as graph:\n",
    "        tf.import_graph_def(graph_def)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Handles to the tensors that are fed (input_ids, word_end_mask) and fetched\n",
    "# (charts, tags) when the session is run in a later cell.\n",
    "input_ids = graph.get_tensor_by_name('import/input_ids:0')\n",
    "word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')\n",
    "charts = graph.get_tensor_by_name('import/charts:0')\n",
    "tags = graph.get_tensor_by_name('import/tags:0')\n",
    "# Session bound to the imported graph.\n",
    "sess = tf.InteractiveSession(graph = graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column count of the zero-initialised id/mask buffers that make_feed_dict_bert allocates.\n",
    "BERT_MAX_LEN = 512\n",
    "import numpy as np\n",
    "from parse_nk import BERT_TOKEN_MAPPING\n",
    "\n",
    "def make_feed_dict_bert(sentences):\n",
    "    \"\"\"Turn pre-split sentences into padded sub-word id and word-end mask arrays.\n",
    "\n",
    "    Args:\n",
    "        sentences: sequence of sentences, each a sequence of word strings.\n",
    "\n",
    "    Returns:\n",
    "        (all_input_ids, all_word_end_mask): two integer arrays with one row per\n",
    "        sentence, zero-padded on the right and trimmed to the longest row. The mask\n",
    "        is 1 on the last sub-word of every word and on the trailing <sep> and <cls>\n",
    "        tokens, 0 elsewhere.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a sentence needs more than BERT_MAX_LEN sub-word positions.\n",
    "    \"\"\"\n",
    "    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)\n",
    "    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)\n",
    "\n",
    "    subword_max_len = 0\n",
    "    for snum, sentence in enumerate(sentences):\n",
    "        cleaned_words = []\n",
    "        for word in sentence:\n",
    "            word = BERT_TOKEN_MAPPING.get(word, word)\n",
    "            # Re-split \"n't\" so that the \"n\" is attached to the preceding word.\n",
    "            if word == \"n't\" and cleaned_words:\n",
    "                cleaned_words[-1] = cleaned_words[-1] + \"n\"\n",
    "                word = \"'t\"\n",
    "            cleaned_words.append(word)\n",
    "\n",
    "        # Local names deliberately differ from the notebook-level `input_ids` and\n",
    "        # `word_end_mask` graph tensors so the two are not confused.\n",
    "        pieces = []\n",
    "        end_flags = []\n",
    "        for word in cleaned_words:\n",
    "            word_pieces = tokenizer.tokenize(word)\n",
    "            if not word_pieces:\n",
    "                # Nothing to mark; writing end_flags[-1] would raise on an empty list.\n",
    "                continue\n",
    "            end_flags.extend([0] * (len(word_pieces) - 1) + [1])\n",
    "            pieces.extend(word_pieces)\n",
    "        # Both trailing special tokens are flagged as word ends.\n",
    "        pieces.extend([\"<sep>\", \"<cls>\"])\n",
    "        end_flags.extend([1, 1])\n",
    "\n",
    "        piece_ids = tokenizer.convert_tokens_to_ids(pieces)\n",
    "        if len(piece_ids) > BERT_MAX_LEN:\n",
    "            # Fail with a readable message instead of a numpy broadcasting error.\n",
    "            raise ValueError(\n",
    "                f\"sentence {snum} needs {len(piece_ids)} sub-word positions, \"\n",
    "                f\"but BERT_MAX_LEN is {BERT_MAX_LEN}\"\n",
    "            )\n",
    "\n",
    "        subword_max_len = max(subword_max_len, len(piece_ids))\n",
    "        all_input_ids[snum, :len(piece_ids)] = piece_ids\n",
    "        all_word_end_mask[snum, :len(end_flags)] = end_flags\n",
    "\n",
    "    # Drop the columns that no sentence in the batch uses.\n",
    "    all_input_ids = all_input_ids[:, :subword_max_len]\n",
    "    all_word_end_mask = all_word_end_mask[:, :subword_max_len]\n",
    "    return all_input_ids, all_word_end_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[  383,  1096, 21767,    88,   757,  1606, 15738,    24,   198,\n",
       "          4049,  2479,  7529,   271,  7644,     9,     4,     3]]),\n",
       " array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Whitespace-split example sentence; punctuation stays attached to its word ('memandu.').\n",
    "s = 'Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar sekiranya mengantuk ketika memandu.'.split()\n",
    "sentences = [s]\n",
    "# i: sub-word id array, m: word-end mask array, as returned by make_feed_dict_bert.\n",
    "i, m = make_feed_dict_bert(sentences)\n",
    "i, m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[[[ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,\n",
       "           -3.2888124 , -2.3206522 ],\n",
       "          [ 0.        , -3.027123  , -4.226623  , ..., -2.8357825 ,\n",
       "           -1.8898286 , -1.4015231 ],\n",
       "          [ 0.        , -4.4751773 , -4.0391564 , ..., -2.0557482 ,\n",
       "           -2.5116968 , -2.3222225 ],\n",
       "          ...,\n",
       "          [ 0.        , -3.0751634 , -5.0968566 , ..., -2.8086288 ,\n",
       "           -2.4999511 , -2.1263318 ],\n",
       "          [ 0.        , -0.21056579, -3.830736  , ..., -2.8927965 ,\n",
       "           -3.2552593 , -2.4863706 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        , -2.7311478 , -3.266908  , ..., -2.5407348 ,\n",
       "           -2.7771876 , -3.4675407 ],\n",
       "          [ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,\n",
       "           -3.2888124 , -2.3206522 ],\n",
       "          [ 0.        , -3.5888338 , -3.3573806 , ..., -1.5800554 ,\n",
       "           -3.0464988 , -3.0329437 ],\n",
       "          ...,\n",
       "          [ 0.        , -3.2029173 , -4.5130672 , ..., -2.2285173 ,\n",
       "           -2.8037972 , -3.142217  ],\n",
       "          [ 0.        , -1.1843511 , -3.4454162 , ..., -3.100471  ,\n",
       "           -3.7265048 , -3.341656  ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        , -4.2309494 , -3.0517614 , ..., -2.8432944 ,\n",
       "           -2.354141  , -3.4422119 ],\n",
       "          [ 0.        , -3.546693  , -3.1645987 , ..., -2.6389394 ,\n",
       "           -2.0354059 , -2.1726856 ],\n",
       "          [ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,\n",
       "           -3.2888124 , -2.3206522 ],\n",
       "          ...,\n",
       "          [ 0.        , -2.4065578 , -3.6648293 , ..., -3.1503863 ,\n",
       "           -2.414618  , -2.4223468 ],\n",
       "          [ 0.        , -2.0922737 , -4.2768326 , ..., -3.038235  ,\n",
       "           -3.219317  , -2.7991467 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[ 0.        , -5.3908677 , -2.167445  , ..., -2.2608237 ,\n",
       "           -2.2742264 , -3.3628154 ],\n",
       "          [ 0.        , -5.3397794 , -2.6094348 , ..., -2.052136  ,\n",
       "           -1.9883934 , -2.5381312 ],\n",
       "          [ 0.        , -4.9658446 , -1.8049864 , ..., -1.8930271 ,\n",
       "           -2.3911085 , -2.587564  ],\n",
       "          ...,\n",
       "          [ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,\n",
       "           -3.2888124 , -2.3206522 ],\n",
       "          [ 0.        , -3.2234542 , -3.848924  , ..., -2.860902  ,\n",
       "           -3.207498  , -3.13295   ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        , -0.9576549 , -0.46345818, ..., -1.1357111 ,\n",
       "           -1.4234623 , -3.7307076 ],\n",
       "          [ 0.        , -2.2830775 , -0.40777344, ..., -1.3792486 ,\n",
       "           -1.1460778 , -3.3147063 ],\n",
       "          [ 0.        , -3.1584206 , -1.6801515 , ..., -0.82609725,\n",
       "           -1.5552241 , -3.3326428 ],\n",
       "          ...,\n",
       "          [ 0.        , -2.715006  , -2.6279292 , ..., -1.4977499 ,\n",
       "           -1.5926964 , -3.6137059 ],\n",
       "          [ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,\n",
       "           -3.2888124 , -2.3206522 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          ...,\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]]]], dtype=float32),\n",
       " array([[ 0, 14, 13,  9,  4,  3,  9,  9, 14,  9, 15,  3,  6,  6, 12,  1]]))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feed the id and mask arrays to the graph inputs and fetch the charts and tags tensors.\n",
    "charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})\n",
    "charts_val, tags_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Crop each sentence's chart to len(sentence) + 1 entries on the two axes that\n",
    "# follow the batch axis.\n",
    "# NOTE: `chart` is overwritten on every iteration, so only the last sentence's chart\n",
    "# is left for the cells that follow (`sentences` holds a single sentence here).\n",
    "for snum, sentence in enumerate(sentences):\n",
    "    chart_size = len(sentence) + 1\n",
    "    chart = charts_val[snum,:chart_size,:chart_size,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import chart_decoder_py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20.88371,\n",
       " array([ 0,  0,  0,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,\n",
       "         8,  8,  9,  9, 10, 10, 11, 11, 12, 13]),\n",
       " array([14,  2,  1,  2, 14, 13,  3, 13,  4, 13,  5, 13,  6, 13,  7, 13,  8,\n",
       "        13,  9, 13, 10, 13, 11, 13, 12, 13, 14]),\n",
       " array([ 1,  4,  0,  4,  0,  5,  0,  0,  0,  7,  0,  5,  0,  5,  5,  0,  0,\n",
       "         5,  0,  0, 13,  7,  0,  1,  4,  0,  0]))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode the cropped chart; later cells unpack the result as (score, p_i, p_j, p_label).\n",
    "chart_decoder_py.decode(chart)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "from nltk import Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "PTB_TOKEN_ESCAPE = {u\"(\": u\"-LRB-\",\n",
    "    u\")\": u\"-RRB-\",\n",
    "    u\"{\": u\"-LCB-\",\n",
    "    u\"}\": u\"-RCB-\",\n",
    "    u\"[\": u\"-LSB-\",\n",
    "    u\"]\": u\"-RSB-\"}\n",
    "\n",
    "\n",
    "def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):\n",
    "    \"\"\"Rebuild an nltk Tree from the parallel span arrays of a decoded chart.\n",
    "\n",
    "    Args:\n",
    "        sentence: sequence of words, indexed by span start.\n",
    "        tags: sequence of tag ids, indexed by span start and looked up in TAG_VOCAB.\n",
    "        score: value stored on the returned tree as its `score` attribute.\n",
    "        p_i, p_j, p_label: parallel sequences; entry k holds the start, the end and\n",
    "            the LABEL_VOCAB index of the k-th span, consumed strictly in order.\n",
    "\n",
    "    Returns:\n",
    "        The first tree produced for the first span, with `score` attached.\n",
    "    \"\"\"\n",
    "    cursor = -1\n",
    "\n",
    "    def build_span():\n",
    "        # Every call consumes exactly one entry of the span arrays.\n",
    "        nonlocal cursor\n",
    "        cursor += 1\n",
    "        start, end, label_idx = p_i[cursor], p_j[cursor], p_label[cursor]\n",
    "        labels = LABEL_VOCAB[label_idx]\n",
    "\n",
    "        if (start + 1) < end:\n",
    "            # Wider span: the following entries describe its left part, then its right part.\n",
    "            parts = build_span() + build_span()\n",
    "            if not labels:\n",
    "                # Unlabelled span: hand the parts straight to the enclosing span.\n",
    "                return parts\n",
    "            node = Tree(labels[-1], parts)\n",
    "            for outer in reversed(labels[:-1]):\n",
    "                node = Tree(outer, [node])\n",
    "            return [node]\n",
    "\n",
    "        # Single-word span: a (tag word) leaf wrapped in its labels, innermost label last.\n",
    "        word = sentence[start]\n",
    "        tag = TAG_VOCAB[tags[start]]\n",
    "        leaf = Tree(PTB_TOKEN_ESCAPE.get(tag, tag), [PTB_TOKEN_ESCAPE.get(word, word)])\n",
    "        for outer in labels[::-1]:\n",
    "            leaf = Tree(outer, [leaf])\n",
    "        return [leaf]\n",
    "\n",
    "    root = build_span()[0]\n",
    "    root.score = score\n",
    "    return root"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(S\n",
      "  (NP-SBJ (<START> Dr) (NP-SBJ (CC Mahathir)))\n",
      "  (VP\n",
      "    (NNP menasihati)\n",
      "    (VB mereka)\n",
      "    (SBAR\n",
      "      (PRP supaya)\n",
      "      (VP\n",
      "        (IN berhenti)\n",
      "        (VP\n",
      "          (VP (VB berehat))\n",
      "          (VB dan)\n",
      "          (VP\n",
      "            (CC tidur)\n",
      "            (ADVP (VB sebentar))\n",
      "            (SBAR\n",
      "              (RB sekiranya)\n",
      "              (S (NP-SBJ (IN mengantuk)) (NN ketika))))))))\n",
      "  (NN memandu.))\n"
     ]
    }
   ],
   "source": [
    "# tags_val[0] begins with a sentence-start slot (its first id decodes to '<START>'),\n",
    "# so skip it to line each remaining tag up with its word.\n",
    "tree = make_nltk_tree(s, tags_val[0, 1:], *chart_decoder_py.decode(chart))\n",
    "print(str(tree))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_str_tree(sentence, tags, score, p_i, p_j, p_label):\n",
    "    \"\"\"Render the parallel span arrays of a decoded chart as a bracketed string.\n",
    "\n",
    "    Args:\n",
    "        sentence: sequence of words, indexed by span start.\n",
    "        tags: sequence of tag ids, indexed by span start and looked up in TAG_VOCAB.\n",
    "        score: accepted for signature parity with make_nltk_tree; not used.\n",
    "        p_i, p_j, p_label: parallel sequences; entry k holds the start, the end and\n",
    "            the LABEL_VOCAB index of the k-th span, consumed strictly in order.\n",
    "\n",
    "    Returns:\n",
    "        The bracketed string built for the first span.\n",
    "    \"\"\"\n",
    "    cursor = -1\n",
    "\n",
    "    def render_span():\n",
    "        # Every call consumes exactly one entry of the span arrays.\n",
    "        nonlocal cursor\n",
    "        cursor += 1\n",
    "        start, end, label_idx = p_i[cursor], p_j[cursor], p_label[cursor]\n",
    "        labels = LABEL_VOCAB[label_idx]\n",
    "\n",
    "        if (start + 1) >= end:\n",
    "            # Single-word span: \"(tag word)\" with bracket characters escaped.\n",
    "            word = sentence[start]\n",
    "            tag = TAG_VOCAB[tags[start]]\n",
    "            text = \"({} {})\".format(\n",
    "                PTB_TOKEN_ESCAPE.get(tag, tag), PTB_TOKEN_ESCAPE.get(word, word)\n",
    "            )\n",
    "        else:\n",
    "            # Keep consuming entries for as long as the next span lies inside this one.\n",
    "            parts = []\n",
    "            while True:\n",
    "                upcoming = cursor + 1\n",
    "                if upcoming >= len(p_i):\n",
    "                    break\n",
    "                if start > p_i[upcoming] or p_j[upcoming] > end:\n",
    "                    break\n",
    "                parts.append(render_span())\n",
    "            text = \" \".join(parts)\n",
    "\n",
    "        # Wrap in the span's labels, innermost label first.\n",
    "        for outer in reversed(labels):\n",
    "            text = \"({} {})\".format(outer, text)\n",
    "        return text\n",
    "\n",
    "    return render_span()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'(S (NP-SBJ (<START> Dr) (NP-SBJ (CC Mahathir))) (VP (NNP menasihati) (VB mereka) (SBAR (PRP supaya) (VP (IN berhenti) (VP (VP (VB berehat)) (VB dan) (VP (CC tidur) (ADVP (VB sebentar)) (SBAR (RB sekiranya) (S (NP-SBJ (IN mengantuk)) (NN ketika)))))))) (NN memandu.))'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# tags_val[0] begins with a sentence-start slot (its first id decodes to '<START>'),\n",
    "# so skip it to line each remaining tag up with its word.\n",
    "make_str_tree(s, tags_val[0, 1:], *chart_decoder_py.decode(chart))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
